"""vgg in pytorch


[1] Karen Simonyan, Andrew Zisserman

    Very Deep Convolutional Networks for Large-Scale Image Recognition.
    https://arxiv.org/abs/1409.1556v6
"""
'''VGG11/13/16/19 in PyTorch (only the VGG16-based classifier below is active; the other variants are commented out).'''

import torch
import torch.nn as nn
from torchvision import models
#
# cfg = {
#     'A' : [64,     'M', 128,      'M', 256, 256,           'M', 512, 512,           'M', 512, 512,           'M'],
#     'B' : [64, 64, 'M', 128, 128, 'M', 256, 256,           'M', 512, 512,           'M', 512, 512,           'M'],
#     'D' : [64, 64, 'M', 128, 128, 'M', 256, 256, 256,      'M', 512, 512, 512,      'M', 512, 512, 512,      'M'],
#     'E' : [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M']
# }
#
# class VGG(nn.Module):
#
#     def __init__(self, features, num_classes=100):
#         super().__init__()
#         self.features = features
#
#         self.classifier = nn.Sequential(
#             nn.Linear(512, 4096),
#             nn.ReLU(inplace=True),
#             nn.Dropout(),
#             nn.Linear(4096, 4096),
#             nn.ReLU(inplace=True),
#             nn.Dropout(),
#             nn.Linear(4096, num_classes)
#         )
#
#     def forward(self, x):
#         output = self.features(x)
#         output = output.view(output.size()[0], -1)
#         output = self.classifier(output)
#
#         return output
#
# def make_layers(cfg, batch_norm=False):
#     layers = []
#
#     input_channel = 3
#     for l in cfg:
#         if l == 'M':
#             layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
#             continue
#
#         layers += [nn.Conv2d(input_channel, l, kernel_size=3, padding=1)]
#
#         if batch_norm:
#             layers += [nn.BatchNorm2d(l)]
#
#         layers += [nn.ReLU(inplace=True)]
#         input_channel = l
#
#     return nn.Sequential(*layers)
#
# def vgg11_bn():
#     return VGG(make_layers(cfg['A'], batch_norm=True))
#
# def vgg13_bn():
#     return VGG(make_layers(cfg['B'], batch_norm=True))
#
# def vgg16_bn():
#     return VGG(make_layers(cfg['D'], batch_norm=True))
#
# def vgg19_bn():
#     return VGG(make_layers(cfg['E'], batch_norm=True))

def set_parameter_requires_grad(model, feature_extracting):
    """Freeze every parameter of ``model`` when ``feature_extracting`` is true.

    Freezing is done by turning off ``requires_grad`` on each parameter, so
    autograd no longer computes gradients for them. When
    ``feature_extracting`` is false, the module is left untouched.

    Args:
        model: an ``nn.Module`` whose parameters should be frozen.
        feature_extracting: if truthy, freeze all parameters of ``model``.
    """
    if not feature_extracting:
        return
    for param in model.parameters():
        # The attribute autograd reads is ``requires_grad``. The previous
        # ``required_grad`` spelling only attached an unused attribute and
        # silently froze nothing.
        param.requires_grad = False


class vgg16_bn(nn.Module):
    """Torchvision VGG16 convolutional backbone followed by a 3-layer MLP head.

    NOTE(review): despite the class name, the backbone is built with
    ``models.vgg16`` (no batch normalisation), not ``models.vgg16_bn``.
    It is also created with ``pretrained=False``, so the convolutional
    weights start from random initialisation.

    Args:
        feature_extract: forwarded to ``set_parameter_requires_grad`` for the
            ``features`` stack, which is intended to freeze it.
        num_class: number of logits produced by the final linear layer.
    """

    def __init__(self, feature_extract=True, num_class=2):
        super().__init__()
        backbone = models.vgg16(pretrained=False)

        # Reuse torchvision's convolutional stack and pooling layer. The
        # stock torchvision classifier is discarded in favour of the head below.
        self.features = backbone.features
        set_parameter_requires_grad(self.features, feature_extract)
        self.avgpool = backbone.avgpool

        # Input width of the head. Presumably this is 512 channels x 7x7 from
        # torchvision's VGG avgpool; TODO confirm if the input size changes.
        flat_dim = 512 * 7 * 7
        hidden_dim = 1024
        self.classifier = nn.Sequential(
            nn.Linear(flat_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_class),
        )

    def forward(self, x):
        """Return logits of shape ``(x.size(0), num_class)`` for input ``x``."""
        pooled = self.avgpool(self.features(x))
        # Keep the leading (batch) dimension and flatten everything else.
        flat = pooled.view(pooled.shape[0], -1)
        return self.classifier(flat)






